{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Imports\n",
    "import logging\n",
    "import random\n",
    "from typing import Iterable, Optional, Sequence, Tuple\n",
    "\n",
    "import math\n",
    "import numpy as np\n",
    "import torch\n",
    "from reagent.ope.estimators.sequential_estimators import (\n",
    "    DMEstimator,\n",
    "    DoublyRobustEstimator,\n",
    "    EpsilonGreedyRLPolicy,\n",
    "    IPSEstimator,\n",
    "    MAGICEstimator,\n",
    "    NeuralDualDICE,\n",
    "    RandomRLPolicy,\n",
    "    RewardProbability,\n",
    "    RLEstimatorInput,\n",
    "    State,\n",
    "    StateDistribution,\n",
    "    StateReward,\n",
    "    ValueFunction,\n",
    ")\n",
    "from reagent.ope.estimators.types import Action, ActionSpace\n",
    "from reagent.ope.test.envs import Environment, PolicyLogGenerator\n",
    "from reagent.ope.trainers.rl_tabular_trainers import (\n",
    "    DPTrainer,\n",
    "    DPValueFunction,\n",
    "    TabularPolicy,\n",
    ")\n",
    "from reagent.ope.test.gridworld import *\n",
    "\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Configurations\n",
    "\n",
    "Change `GAMMA` to adjust the discount factor applied to future rewards."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "GAMMA = 0.9  # discount factor used by the DP trainer, the value functions, and the estimators"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluate Estimators on a Policy\n",
    "\n",
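    "Given a dataset of trajectories (episodes) generated by a logging (behavior) policy, we evaluate the target policy with seven off-policy estimator configurations for the sequential setting: direct method (DM), IPS, weighted IPS, doubly robust (DR), weighted DR, MAGIC, and neural DualDICE."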
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_estimators(log, target_policy, value_fun, ground_truth):\n",
    "    \"\"\"Run every sequential OPE estimator on the same logged trajectories.\n",
    "\n",
    "    Args:\n",
    "        log: logged trajectories, passed to RLEstimatorInput as `log`.\n",
    "        target_policy: policy whose value is being estimated.\n",
    "        value_fun: value function handed to the estimators as `value_function`.\n",
    "        ground_truth: target-policy value function on the true environment,\n",
    "            passed as `ground_truth` (the caller reads the resulting RMSE).\n",
    "\n",
    "    Returns:\n",
    "        dict mapping an estimator name to the result of its `evaluate` call.\n",
    "\n",
    "    Uses the notebook-level globals `GAMMA` and `device`.\n",
    "    \"\"\"\n",
    "    estimator_input = RLEstimatorInput(\n",
    "        gamma=GAMMA,\n",
    "        log=log,\n",
    "        target_policy=target_policy,\n",
    "        # Use the argument, not the notebook-level `value_func` global.\n",
    "        value_function=value_fun,\n",
    "        ground_truth=ground_truth,\n",
    "    )\n",
    "\n",
    "    # action_dim matches the ActionSpace(4) used for the gridworld; state_dim=2 presumably\n",
    "    # encodes the (row, col) grid position -- confirm against the State representation.\n",
    "    dice_results = NeuralDualDICE(\n",
    "        state_dim=2,\n",
    "        action_dim=4,\n",
    "        deterministic_env=True,\n",
    "        batch_size=512,\n",
    "        training_samples=10000,\n",
    "        value_lr=0.001,\n",
    "        zeta_lr=0.0001,\n",
    "        device=device,\n",
    "    ).evaluate(estimator_input)\n",
    "\n",
    "    dm_results = DMEstimator(device=device).evaluate(estimator_input)\n",
    "\n",
    "    ips_results = IPSEstimator(\n",
    "        weight_clamper=None, weighted=False, device=device\n",
    "    ).evaluate(estimator_input)\n",
    "    ips_results_weighted = IPSEstimator(\n",
    "        weight_clamper=None, weighted=True, device=device\n",
    "    ).evaluate(estimator_input)\n",
    "    dr_results = DoublyRobustEstimator(\n",
    "        weight_clamper=None, weighted=False, device=device\n",
    "    ).evaluate(estimator_input)\n",
    "    dr_results_weighted = DoublyRobustEstimator(\n",
    "        weight_clamper=None, weighted=True, device=device\n",
    "    ).evaluate(estimator_input)\n",
    "\n",
    "    # NOTE(review): `loss_threhold` (sic) is kept as written; presumably it matches the\n",
    "    # keyword MAGICEstimator.evaluate looks up -- confirm before correcting the spelling.\n",
    "    magic_results = MAGICEstimator(device=device).evaluate(\n",
    "        estimator_input, num_resamples=10, loss_threhold=0.0000001, lr=0.00001\n",
    "    )\n",
    "\n",
    "    return {\n",
    "        \"dm\": dm_results,\n",
    "        \"ips\": ips_results,\n",
    "        \"ips_weighted\": ips_results_weighted,\n",
    "        \"dr\": dr_results,\n",
    "        \"dr_weighted\": dr_results_weighted,\n",
    "        \"magic\": magic_results,\n",
    "        \"dice\": dice_results,\n",
    "    }"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generate Trajectories, Policies, and Evaluate Estimators\n",
    "\n",
    "We can see that the IPS estimators perform well for short episodes, but their error grows as the episode length increases. This is expected: a trajectory's importance weight is a product of per-step probability ratios, so its variance tends to grow with the horizon."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "device - None\n",
      "GridWorld:\n",
      "⭕⬜⬜⬜\n",
      "⬜⬜⬜⬜\n",
      "⬜⬜⬜⬜\n",
      "⬜⬜⬜⭐\n",
      "\n",
      "Opt Policy:\n",
      "⬨⇨⇨⇩\n",
      "⇩⇩⇨⇩\n",
      "⇩⇩⇩⇩\n",
      "⇨⇨⇨⬧\n",
      "\n",
      "Opt state values:\n",
      "  3.27  4.74  6.38   8.2\n",
      "  3.12  4.58   8.2   8.0\n",
      "  4.58   6.2   8.0  10.0\n",
      "   6.2   8.0  10.0   0.0\n",
      "\n",
      "Target Policy ground truth values:\n",
      " 0.299  1.52  3.17  5.86\n",
      "  0.71  1.62  5.43  5.96\n",
      "  2.48  3.98  5.93  9.17\n",
      "  4.35  6.37  9.04   0.0\n",
      "\n"
     ]
    }
   ],
   "source": [
    "random.seed(1234)\n",
    "np.random.seed(1234)\n",
    "torch.random.manual_seed(1234)\n",
    "\n",
    "logging.basicConfig(level=logging.WARNING)\n",
    "\n",
    "# Use the GPU when available; otherwise `None` is passed through to the estimators.\n",
    "device = torch.device(\"cuda\") if torch.cuda.is_available() else None\n",
    "print(f\"device - {device}\")\n",
    "\n",
    "# Cell codes (presumably): \"s\" start, \"g\" goal, \"W\" wall, \"0\" open cell.\n",
    "# NOTE(review): the saved output of this cell shows a 4x4 grid without walls, so it\n",
    "# predates this 5x5 layout -- re-run the notebook to refresh it.\n",
    "gridworld = GridWorld.from_grid(\n",
    "    [\n",
    "        [\"s\", \"0\", \"0\", \"0\", \"0\"],\n",
    "        [\"0\", \"0\", \"0\", \"W\", \"0\"],\n",
    "        [\"0\", \"0\", \"0\", \"0\", \"0\"],\n",
    "        [\"0\", \"W\", \"0\", \"0\", \"0\"],\n",
    "        [\"0\", \"0\", \"0\", \"0\", \"g\"],\n",
    "    ],\n",
    "    max_horizon=1000,\n",
    ")\n",
    "print(f\"GridWorld:\\n{gridworld}\")\n",
    "\n",
    "action_space = ActionSpace(4)\n",
    "\n",
    "# Solve the true gridworld with dynamic programming; `opt_policy` is the policy the\n",
    "# trainer optimizes and `opt_value_func` holds its state values.\n",
    "opt_policy = TabularPolicy(action_space)\n",
    "trainer = DPTrainer(gridworld, opt_policy)\n",
    "opt_value_func = trainer.train(gamma=GAMMA)\n",
    "\n",
    "print(f\"Opt Policy:\\n{gridworld.dump_policy(opt_policy)}\")\n",
    "print(f\"Opt state values:\\n{gridworld.dump_value_func(opt_value_func)}\")\n",
    "\n",
    "# Logged data comes from a random behavior policy; we evaluate an epsilon-greedy\n",
    "# version of the optimal policy.\n",
    "behavior_policy = RandomRLPolicy(action_space)\n",
    "target_policy = EpsilonGreedyRLPolicy(opt_policy, 0.3)\n",
    "\n",
    "# `value_func` (given to the estimators) is computed on a noise-perturbed model of the\n",
    "# gridworld, while `ground_truth` is computed on the real gridworld.\n",
    "model = NoiseGridWorldModel(gridworld, action_space, epsilon=0.3, max_horizon=1000)\n",
    "value_func = DPValueFunction(target_policy, model, GAMMA)\n",
    "ground_truth = DPValueFunction(target_policy, gridworld, GAMMA)\n",
    "\n",
    "print(\n",
    "    f\"Target Policy ground truth values:\\n\"\n",
    "    f\"{gridworld.dump_value_func(ground_truth)}\"\n",
    ")\n",
    "\n",
    "log_generator = PolicyLogGenerator(gridworld, behavior_policy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating estimators on 5-length episodes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "I0812 115834.496 sequential_estimators.py:742] Data loading time: 12.767301744000001\n",
      "I0812 115835.276 sequential_estimators.py:786] Samples 0, Avg Zeta Loss -0.0024972213432192802, Avg Value Loss 0.014194303192198277,\n",
      "Time per 1000 samples: 16.181496048\n",
      "I0812 120436.815 sequential_estimators.py:786] Samples 1000, Avg Zeta Loss 1.762339136944153, Avg Value Loss -1.3786156352562846,\n",
      "Time per 1000 samples: 7850.563727696001\n",
      "I0812 121012.476 sequential_estimators.py:786] Samples 2000, Avg Zeta Loss 3.8674179503917667, Avg Value Loss -2.7155063413381604,\n",
      "Time per 1000 samples: 7573.734335957\n",
      "I0812 121733.886 sequential_estimators.py:786] Samples 3000, Avg Zeta Loss 5.210317928791042, Avg Value Loss -3.3756288778781895,\n",
      "Time per 1000 samples: 8524.390936679\n"
     ]
    }
   ],
   "source": [
    "result_maps = {}\n",
    "xs = []\n",
    "lengths = [5, 10, 100, 400]\n",
    "# Evaluate the estimators as the episode length (horizon) grows. Each setting logs\n",
    "# 50 episodes for each of 5 randomly sampled states (presumably the start states).\n",
    "# Interrupting the kernel stops the sweep early and plots the settings finished so far.\n",
    "try:\n",
    "    for length in lengths:\n",
    "        small_log = []\n",
    "        for state in random.sample(list(gridworld.states), 5):\n",
    "            small_log.extend([log_generator.generate_log(state, length) for _ in range(50)])\n",
    "        print(f\"Evaluating estimators on {length}-length episodes\")\n",
    "        results = evaluate_estimators(small_log, target_policy, value_func, ground_truth)\n",
    "        # NOTE(review): report()[3] is assumed to be the summary entry exposing `.rmse`;\n",
    "        # confirm against EstimatorResults.report.\n",
    "        rmses = {name: result.report()[3].rmse.cpu().numpy() for name, result in results.items()}\n",
    "        for name, rmse in rmses.items():\n",
    "            result_maps.setdefault(name, []).append(rmse)\n",
    "        # x-axis value for this setting is the episode length.\n",
    "        xs.append(length)\n",
    "except KeyboardInterrupt:\n",
    "    pass\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "for name, rmse_curve in result_maps.items():\n",
    "    # Trim in case an interrupt landed after some curves were extended but before xs was.\n",
    "    ax.plot(xs, rmse_curve[: len(xs)], label=name)\n",
    "\n",
    "# Log scale vastly improves visualization\n",
    "ax.set_xscale(\"log\")\n",
    "ax.set_yscale(\"log\")\n",
    "ax.set_xlabel(\"Episode length\")\n",
    "ax.set_ylabel(\"RMSE\")\n",
    "ax.set_title(f\"OPE estimator RMSE vs. episode length (gamma={GAMMA})\")\n",
    "ax.legend(loc=\"best\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD3CAYAAAANMK+RAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8li6FKAAAgAElEQVR4nO3deXwU9f348dfslZsAciOCJ4rIIXJVxChtPdpqtfjRqvVotbUePez1/fVbKVq/bf221vartba14K18PLHawzYaPDlF5VBAacIRCEcg5N7dmfn9sZOwCdnsJuxmd3bfz8fDR3ZnZmfeE+J7Zt+fz3w+hm3bCCGEcD9PugMQQgiRHJLQhRAiS0hCF0KILCEJXQghsoQkdCGEyBK+dB24vLxcutcIIUQvzJkzx+hqedoSOpGg4m5TUVFBWVlZn8STKm4/B4k//dx+DhJ/8pSXl8dcJyUXIYTIEpLQhRAiS0hCF0KILJHWGroQQiSDbdvU1tZiWVZK9j9o0CB2796dkn3H4vF4GDhwIIbRZftnlyShCyFcr7a2lqKiIvLz81Oy//z8fEpKSlKy71haWlqora3liCOOSPgzUnIRQrieZVkpS+bpkp+f3+NvHJLQhRAiS2RFQm8JmdQ2BtMdhhBCpFVWJPTyDbv4vyUfpzsMIYSgoaGBMWPGcNlll9Hc3Nynx06oUVQpNR5YDNyjtb5PKfU0MNhZPRBYCvwcWAOscpbv1lpfkrrQD2oOmTS0hvviUEIIkZCnnnqqz48ZN6ErpYqAe4H2502jE7VS6i/AAuftBq11nz8fG7JsGoNmXx9WCCEAOHDgAHPnzqW5uZkzzjgDgDFjxrB27Vpqa2u57rrraGlpYfTo0SxcuJBdu3Zx/fXX09LSgs/n48EHH2TUqFGHHUciJZdW4HyguvMKpdQJwCCt9dLDjuQwhE2L5pAkdCFEejz22GNMmDCBN954g8mTJ3dYN3/+fG655RZef/11hg8fzsqVK5k3bx633nor5eXlfOtb3+LOO+9MShxx79C11mEgrJTqavV3gN9FvR+mlHoBGAL8Xmv9eHf7rqioiBtgQ0ND3O02VNvs3mcntL90SOQcMpnEn35uP4dUxz9o0KAO3RbPfmBF0o/x6g1TY6577733mDVrFvX19UyZMoW2uZrr6+tZvnw58+bNo76+nttuuw2AZcuWsW7dOubPn49pmgwePJj6+vpD9ltTU8O6desSjrHXDxYppQqBTwM3O4v2Aj8FHgOKgOVKqQqt9fZY+0hk9LJERjn76M3NvHdgB2Vlp/f8RPpAJo3U1hsSf/q5/RxSHf/u3bs7PPiz4gdnJ3X/9fX13T5Y5Pf7KSwspKSkhMbGxvanO0tKSvB4PO3r2ng8Hp555hlGjhzZ7XGHDh3K+PHjOyxL1WiLpwNvaq0tInfy9Vrrv2itW7XWtU7j6AmHsf+EhUyLJqmhCyHSZOzYsaxaFekP8tprr3VYd9ppp7FkyRIA5s2bx7/+9S+mT5/O4sWLAXj11Vd58sknkxLH4ST06cDatjdKqdlKqT87rwuACcCGpEQZR8i0aQqa7V9zhBCiL1111VUsXbqUOXPmsGHDhg5PeM6fP58//vGPzJ49m02bNnHWWWcxf/58XnjhBWbPns3tt9/OjBkzkhJHIr1cpgB3A2OAkFJqLnAxMBx4I2rTt4GvKKWWATZwl9b6kIbUVAiZFqZtEzQt8nzevjikEEK069+/f4c78/nz57e/Li4u5t///neH7UeMGMErr7yS9DgSaRRdBXRV/Lqp03Zh4PqkRpegkBm5GjYFTUnoQoiclRVPiobMSKlF+qILIXJZdiR0p17VLAldCJHDsiOhh9tKLvL4vxAid2VHQrciJZcmeVpUCJHDsiOhmxZFAa/0RRdC5LSsSeilBX4apeQihEiztuFz0yFLErpNab5fBugSQmScVE1c3ZWsmCQ6aFoM
LAxIyUUIkRZdDZ973HHH8YUvfIEBAwYwb968Pokje+7QC/zSD10IkRZdDZ8bDoc599xz+yyZky136GHTipRcJKELIYD//mZyBruK9j9/+HLMdevXr+fMM88EaP8JMHVq7CF3UyErEnrQtCgt8FFd15LuUIQQGaC75NsbXY1VHs227fYhc6Nr5oFAIKlxxJNVJRdpFBVCpEN3w+f2paxI6GErUnKRRlEhRDp0N3xuX8qikov0QxdCpEd3w+f2pay4Q28vucgduhAih7k+oVu2jWnZ9MvzyVguQoic5vqEHjIt/F6DwoBPauhCiJyWBQndJuD1UCiDcwmRszweDy0t2dVtuaWlBY+nZyna9Y2iIdPC5/WQ5/NgWjZh570QIncMHDiQ2trauP3Fe6umpoahQ4emZN+xeDweBg4c2KPPJJTQlVLjgcXAPVrr+5RSDwFTgL3OJr/SWr+slLoI+CGQD9yrtV7QmxPpicgduoFhGBQEvDSFTPpJQhcipxiGwRFHHJGy/a9bt47x48enbP/JEjehK6WKgHuB8k6r/p/W+qWo7UqAXwOnAiHgXaWU1lo3pCRyR6SGHkngbWWXfvn+VB5SCCEyUiK3sq3A+UB1nO2mAiu01nVa6ybgLeCMJMUZU8i08Dl1pkK/V/qiCyFyVtw7dK11GAgrpTqvukUp9QNgJ3ATMBzYHbV+FzCsu31XVFTEDbChoaHb7aqbbIItNhUVFZitFm8tXcHWYiPufvtSvHPIdBJ/+rn9HCT+vtHbRtFHgTqt9Uql1PeBO4DOAxgYgN3dTsrKyuIeqKKiotvt1u88wD/2bKCsbCqLalYzdvxopo/pWUNCqsU7h0wn8aef289B4k+e8vLO1e+DepXQtdbRe3wZeAB4AhgStXwY8Gpv9t8TobCFzxu5Iy8MeGWALiFEzupVQldKaeBOrfUHwCxgLbAcmKCUKgVMYDrwzeSH3FHIivRDx6mhN0kNXQiRoxLp5TIFuBsYA4SUUnOBecCDSqkmoB74qtY6qJSaB7wOWMAdWuvmVJ9AdC+XooBPZi0SQuSsRBpFVwFdFY+mdbHt08DTSYsuAUHzYMmlIOCVAbqEEDnL9U/ghM2okkvAS6PU0IUQOcr1CT0Y/WCR1NCFEDnM9Qm9bbRF2nq5SMlFCJGjXJ/Qw6aN3yONokII4fqEHl1yKZB+6EKIHOb6hB4y7Q4lF6mhCyFylesTetiSfuhCCEE2JPRg+GCjaIFfGkWFELnL9Qk9ZNlRd+heuUMXQuQs9yf0To2iTSGpoQshclOWJPSDJZfWkIVpdTtqrxBCZKUsSOgHSy4ew4jU0aXrohAiB2VBQj9YckH6ogshclhWJPSA9+CUc5GGUamjCyFyTxYk9IOP/gMUBnw0SU8XIUQOyoKEbuGLLrlIX3QhRI7KioR+aMlFEroQIve4P6FHPViEjOcihMhh7k/o4U4ll4CXJunlIoTIQXHnFCUyUfR4YDFwj9b6PqXUSGAhkAeYwJVa62qlVAh4K+qjc7TWKc2uIcvuWHLxS6OoECI3xU3oSqki4F6gPGrxz4AHtdZaKXUjcCvwfaBOa93VhNIp02U/dEnoQogclEjJpRU4H6iOWvYt4Dnn9R6gX4riiyvSy0X6oQshhGHbiY17opSaD+zRWt8XtcwLvArM11q/ppRqAP4GjASe01rfHWt/5eXlttfrjXvchoYGiouLY67/+fsWN5xoMDAvktTf2WWzvclm7pjMaR6Idw6ZTuJPP7efg8SfPKZpMmfOHKOrdQnV0LviJPNHgQqt9WvO4u8DTwAhYIlS6g2t9fJY+ygri1+dqaio6Ha7u9a/yRmnT2VwcR4ATet20vifvZSVndyb00qJeOeQ6ST+9HP7OUj8yVNeXh5zXa8TutMoWqm1/mnbAq31A22vlVKvAScDMRN6MoRMC7+nY8lFGkWFELmoVwldKXUFYGmtfxy17DjgbuBiZ9FM4JmkRRpDyLTx+zr2Q2+WMdGFEDko
kV4uU5xEPQYIKaXmAkOAFqVUhbPZeq31jUqpNcAyIAy8qLVekeoTCJrWIWO5yJOiQohcFDeha61XAQkVj7TWPwF+kpTIEmDbNmHLbp/ggvYnRSWhCyFyT+Z0BemFsGXj8xgYxsGELoNzCSFylasTerDTQ0VIP3QhRA5zdUIPmR0f+ydqxqJE+9cLIUS2cHVCD3caCx3A5/Hg93poDVtpi0sIIdLB1Qk9UnI59IGpQhkTXQiRg1yd0CMll0NPodDvpVnq6EKIHOPyhH5oyQXpiy6EyFGuT+jRj/23KZRJLoQQOcjlCb3rkkuBXx4uEkLkHpcn9K5LLkUyr6gQIge5PqF37oeOU0OXO3QhRK5xd0K37EOeFEVq6EKIHOXuhB6O1ctFSi5CiNzj7oRuxSq5SKOoECL3uDuhmzFKLn6poQshco+rE3rQtPDJHboQQoDbE3q4m0f/pVFUCJFrXJ3QQ52mn2sjjaJCiFzk6oQeu+QiNXQhRO6JO6cokYmixwOLgXu01vcppYYAjwD9gW3AFVrrVqXURcAPgXzgXq31glQGH7PkIv3QhRA5KO4dulKqCLgXKI9a/CtgodZ6BlAJXKGUKgF+DZwLnA78UClVnMrgu5qCDim5CCFyVCIll1bgfKA6alkZ8KLzejFwDjAVWKG1rtNaNwFvAWekKG5oq6HHGD5XJooWQuSauCUXrXUYCCuloheXaK2bnde7gGHAcGB31DZty2OqqKiIG2BDQ0PM7Sq3WDTkG1Q0be6wPGzZNLTaCe2/L3R3Dm4g8aef289B4u8bCdXQuxCMem0Adqdl0ctjKisri3ugioqKmNu9/c+PGDu0hLJJIw9Zd9vq1/jUrNkEfOlv9+3uHNxA4k8/t5+DxJ885eXlMdf1NtvVK6UKndfDnHLMDmBI1DbDOpVpki5kdT2nKNIwKoTIQb29Q/8HcCHwJHAx8DKwHJiglCoFTGA68M0kx9tBrEf/iWoY7V/gT2UIQgiRMeImdKXUFOBuYAwQUkrNBa4AHldK3QpsABZprcNKqXnA64AF3BFVZ0+JWI2iSF90IUQOSqRRdJXTq6WzQ5ZprZ8Gnk5adHFE7tBjlFxkGjohRI5Jf4vhYej+Dl36ogshcov7E7onVqOoTxpFhRA5xeUJvetH/5GSixAiB7k8oXc9BR0yJroQIge5PqF33w9dauhCiNzh7oRudVNykTt0IUSOcXdC77bkIv3QhRC5xfUJXRpFhRAiwuUJvZsHi6QfuhAix7g8ocd59F/6oQshcojLE3q8O3RJ6EKI3OHahG7bdqRR1CM1dCGEwM0J3bRsPIaBN+aj/1JDF0LkFtcm9JBp4/d1ncyRGroQIge5N6FbFv4Y5RaAIqmhCyFyjGsTejAc+7F/gDyfh5BpEbasPo1LCCHSxbUJPWzFnn4OwDAMCvxemuUuXQiRI1yb0IPd9EFvUyR1dCFEDnFtQu/uoaI2BVJHF0LkENcm9HA3DxW1kYeLhBC5JO4k0V1RSn0N+ErUotOAlUAR0Ogs+54zwXRKJFZykb7oQojc0auErrX+C/AXIsn9DOBy4CTgWq312qRH2YXISIvd36EX+GUIXSFE7khGyWU+cGcS9tMjIdOO+dh/m8isRZLQhRC5wbBtu9cfVkpNA27WWl+llKoA9gNHAB8C39ZaN8f6bHl5ue31euMeo6GhgeLi4kOWf7jf5u1dNl87IXZSf6bSYmShwcwh3d/Jp1qsc3ALiT/93H4OEn/ymKbJnDlzukxqvSq5RLkeWOS8/h2wTmu9USl1L3AL8L/dfbisrCzuASoqKrrczt64m832DsrKJsT87HuvbeKIogBl00Ynci4pE+sc3ELiTz+3n4PEnzzl5eUx1x1uQj/TSdxorZ+PWv434LLD3He3Qlbs6efaFMk0dEKIHNLrhK6UOhJo0Vq3KKUM4FXgy1rrncAsIKWNo6E4j/7j9EPfXd+ayjCEECJjHE6j6HCgmsjduQ3cB7yslFoCHO28
T5mQZcecT7SNNIoKITLN4bRbxtPrO3St9Qrg3Kj3zwLPJi2yOILhREou0g9dCJEZQqbF39bt5JHlW/jdlyZw5IDCpB/jcGvoaRO2Eii5SD90IUSatYRMXvigmsdWbGHMwEJ+fM5YRvYvSMmxXJvQQ2b8kkuRlFyEEGnS0Brmmfe28eTKbZwyoh93XXgKJw/vl9JjujahB834JRcZnEsI0df2N4d4atVWnnlvOzPGDOT3ahLHDe6bPuyuTehh0yLgS6BRVGroQog+sKehlcdWbOGva3dw1vGDWXjFFEaloE7eHdcm9JBpUxSQfuhCiPSqrmvmkeVb+NdHNZw3bhiPXz2NYf3y0xKLaxN6QiUXv5RchBCpUVnbyENLq3jzkz18ceJInv7qDAYWBdIak2sTeti04462WBjw0hI2sWwbj5He8VyEENlhQ009Dy2rYtXWfajJR/Lc9TPpl+9Pd1jg5oQesuKPh+4xDPJ8XlpCJoUB156qECIDfLC9joVLK/mopp4rph7FbeeemHF5JbOi6YFgAo/+E1V2ybRfvBAi89m2zfKqfTzwkUXThnVcNe0ofnnhePJ88UeKTQfXZrmwZce9Q0f6ogshesG2bV7/ZA8Ll1bR0BpmxiCD71w0I267Xbq5NqEnMgUd0hddCNEDpmVTvmEXC5dW4vEYXDtjDGcdP5g3Xl+S8ckcNyf0kJlYyaUo4KVR+qILIboRMi3+vn4nDy+ron9BgJvPPJZPHX0Ehss6U7g4oSdWcikM+GiWO3QhRBdaQiYvrtnBoyuqOGpAIT/+7ImcOqq/6xJ5Gxcn9ARLLtIXXQjRSWMwzLOrt/PEqq2cPLwfv/jCeMaPKE13WIfN1Qk93uBcSKOoECJKXXOIRe9u5enV25k2egD3zp3E8UMyY67QZHBxQrfxeeJ/LSoM+GQ8FyFy3J6GVp5YuZUX11Rz5vGD+csVUziqj8dZ6QsuTuiJ93JplJKLEDlphzPOyisf1XDuScN4LI3jrPQFVyf0eI/+AxT5vexvDvVJTEKIzFBZ28jDy6p44+M9XDhhBPqrMzgizeOs9AX3JnTLTqhfaEHAS3VdS5/EJIRIr4276lm4tIqVW/ahTs2scVb6Qq8SulKqDHgaWOcsWgP8DHgE6A9sA67QWrcmN9yDEi25SD90IbLfmuo6Fiyt5KOd9Vx+2lH85NwTKcrB4T4O54yXaK3ntr1RSj0MLNRaL1JK/Rq4AliQnDAPFUpgtEXa+qFLLxchso5t26zcso8FS6vYvr+Zr0w7il9ekLnjrPSFZF7CyoAbnNeLgZtTm9AT74cujaJCZA/btnlz814WvlNJXUuYa2aM5ryThrri0fxUM2zb7vGHnJLL/UAVUALcDizSWg901o8F/qS1PjPWPsrLy22vN/6VtKGhgeLiQ/uJ/nCFxS9OM/DGeaJrS4PNC1tsvjUuff/Ysc7BLST+9HP7OSQjfsu2WbMPyqsjOWvOCINTBtAncx1k0u/fNE3mzJnT5Un39g59E3An8BQwGqgAog9gAHGvFGVlZXEPVFFRcch2pmVjrHyNOWedFffzm/c08tKuNZSVzYi7bap0dQ5uIvGnn9vP4XDiD5sWf1u/k4eXbaE038cPzhvDrGP6dpyVTPr9l5eXx1zXq4Sutd4OPOG8/Y9SaicwQilVqLVuAoYB1b0NOJ5Eyy04sxZJyUUI94keZ+XI/oX812dO4LSjBrh2nJW+0NteLpcBJ2qt5yulBgFDgQeBC4EngYuBl5MfbkRPE7oMziWEezQGwzz33nYeX7mVccP68fMvjOeULBhnpS/0tuTyEnCpUuotwAPcCKwGnlRK3QpsABYlOdZ2kZEWE7tKFzqDc9m2LVd2ITJYXXMI/e429OptTM3CcVb6Qm9LLg3ARV2s6pMiU9C08HsSu0P3eT14PQZB08rp7kxCZKq9jUGeWLmFxR9UM/u4wTx4+RRGD8y+cVb6git73ocTnNyiTaEza5EkdCEyx84DLTy6
fAv/+HAn55w0lEevmsrw0oJ0h+VqrkzooQTnE23T1jCahYOrCeE6VbVNPLysiiUf7+bCU0aw6NrpDCrOS3dYWcGVCT0YTrxRFKeOLg2jQqTXpl0NLFxWyYqqfVwyeSTPXTeT0oLcGWelL7gyoYetxBtFkTHRhUirtdV1LNxkUbP+PS6fMor/Pic3x1npC678rQZ70G2RtpKLjOciRJ+xbZtVW/ez4J1KtuxrYuYAgz9ePJN8v7RjpZIrE3qoF42iUnIRIvVs2+atzXtZsLSSuuYQV08fzXnjhvHWG69LMu8DLk3oPWsUjQzQJSUXIVLFtGxe3biLhUurALhmxmjmnDAEbwLTRIrkcWlC71nJpUiG0BUiJcKmxd8/rOHhZVWU5Pm4YdYxnHFs346zIg5yb0LvwZW/rR96Mr2/bT/rdh7gy1NGyR+vyDmtYWecleVbGNk/nx9++gSmyjgraefShN67fujJYtk2v/zXBvY3h9jXFOLGM46RP2SRE5qCYZ59r5onVm7hxKEl3Pn5k5kwUsZZyRQuTeg97+WypyGYtOP/88MaCvxe/nDpZG7U72HZNjfPPlaSushaB1pCLHp3G0+v3saUUQP47ZcmMnZoSbrDEp24NqEHepTQfUlrFA2ZFg+8uZl5551E/8IA9186mZv1asKWzXfKjpOkLrJK9DgrZxw3iD99+VTGDCxKd1giBlfO2RSybHw96bboT14N/fn3qxk9sJApowYA0L/Az/2XTmb1tv385rVN9GYGKCEyzc4DLfy6fCNqwVKagiaPXDWVn543TpJ5hnNnQu/po/8Bb1J6uTQFwyxYWslNZxzbYXm/fD+/v2QSa6oP8KvyjZLUhWtt3dfEnf/4kCseXo7f62HRtdP50WfGMkIGzXIFdyZ0yyLQwzv0ZJRcnlq1jSmj+ndZOyzJ93PfJZP4aGc9d/1rI5YkdeEiH+9u4CcvreOrj69icHEez143k2+XHSeDZrmMOxO6afdohu/CgO+wnxTd3xziiVVbuWHWMTG3Kc7z8X+XTOLjPQ384pUNktRFxlu34wDfe/4DbtLvcfzgYp6/fibfmHUM/WXQLFdyaULvxXjoh1lyeXhZFZ8+YQij4ozBW5zn43dzJ1JV28id//gI05KkLjKLbdus2rKPm/RqfrR4DdNGD2Dx12dy9fTRFOe5sp+EcLjyXy9k2j3s5XJ4/dBr6lv465pqnrx2ekLbFwV8/O5Lk/juc+/zs398yOwCSeoi/Wzb5u3/7GXBO1XsawpyzYzIOCs9aY8Smc2VCT1oWvgSnIKOJAzO9ee3/sMXJ45kcA/qiQUBL7/90kRufe4DntphM/vMnsUsRLJYts1rG3ezcGklpmVz7YwxzBkr46xko14ndKXUz4GzAD9wF/A5YAqw19nkV1rrl5MX6kE9nYIu4PVg2jZh0+pR7R2gcm8jSz7ew7PXzehxnPl+L7+5eAJfXbCEn768nts/N06SuugzYdPinx/W8NCyKoryfFx/+tGccewgPPKsRNbqVUJXSs0GJmmtZyqlBgIfAP8G/p/W+qXkh9lRT0suhmE4PV1MSgt6llD/8OZmrpx6FP3ye9dIlO/3cu3xBn+tDXPbS+v52efG9fiiIkRPtIZNXlq7k0eWVzG8Xz7fn3MC00bLOCu5oLeZ5W1AOa/3A4G+bGDt6QQX9LIv+rodB1hTXcelpx7Zwwg78nsMfvXFU2gJmfz4r+sImdZh7U+IrjQHTR5fsYWL/vwOb3yyhzs+dzIPXHYq08cMlGSeI4zDfQhGKfV14FPO2+FAPrATuElrvSfW58rLy22vN/6A9w0NDRQXF3dY9tAmiymDDE4ZkPgf6a/XWFx5nMGwgsQ/88cNFhMGGMwccnj/M7SdQ9iyefQTGwO48lgDn0tqmF39G7iJ2+Mnzjk0h23e2gVv1tgcWwJnDzcYWZRZf1tu/zfIpPhN02TOnDld/gMfVqOoUupC4HrgM079vE5rvVIp9X3gDuDG7j5f
VlYW9xgVFRWHbPfC3veZPGEks44dlHCsD21byfiJxzN+RGIjwy2rrKV10wZ+8KXph10iiT6HM8+0+PGLa3l5v81dF5xCwJf55Zeu/g3cxO3xE+McahuDPLFqKy98uJ1Zxw5i4TmjOfqIzHw03+3/BpkUf3l5ecx1h9Moeg4wD/is1no/EH2Ul4EHervveII9fPSf9qdFEyu52LbN71//hBvOOCbp9W6/18MvLhjPf7+0jh8uXsNdF44nzydTc4nE1dS38OjyLfx9/U4+c+JQHrlqqjyaL6C3dW+lVCnwG+B8rfVeZ5lWSk1wNpkFrE1qpFHCVs96udDDSS5e3bgby7b59NghvYywez6vh//5/MkU+L18//k1tMhsSiIBW/c1cec/P+Tyh5bj8xgsunY6/yXjrIgovb1DvxQYACxSqq1tlHnAg0qpJqAe+GrywuyopxNc0N4oGn88l7Blcf8bm/n+nONT2r3L5/Xws8+PY/7fPuR7z3/A3RdNkEl0RZc+3t3AE59YbF67irmTRvLsdTPl0XzRpV4ldK31n4A/dbFq2uGHFF/vern4Eiq5vLR2J0NKAswYM/AwIkyMz+Ph9vPHcfvf1/Pd5z7gNxdNoCAgSV1ErN95gIXvVPJB9QGmDzT4zZUz5dF80a3Mb5HrQuQOvYclF3/8p0VbQiZ/fvs/3HRG380+5PUY/PS8cQwtyeM7z71PU5Im4hDu9e7Wfdzy9Hv84IU1TDkqMs7K2cMNSeYiLlf+hfR0CjraxnOJU6t+evV2xg3rl3BPmGTxegzmnXcSP//nR3z7mff57dyJFAVc+U8jeikyzkotDy2tZE9jkKunj+Y3J8s4K6JnXJk1QqaFv4d9uAsDPnYcaI65vr4lxKMrqnjg0lOTEGHPeQyDH59zIr98ZQPfevp9fjd3otyR5QDLtqnYuJuFy6oImhbXTh/Np08cIkNEiF5xZcbodaNoNyWXx1ZsYdYxgzhmUPr68XoMg//67Fj+998bueXp97j3kkmS1LNU2LJ4xRlnJd/v5WszxzD7OBlnRRweV2aLXpdcYiT0PQ2tPPvedh67uk/adLvlMQx+9OkTuPvVTdykV3PfJZMo6eU4MiLzBMMWL63dwcPLqxjWL5/vnS3jrFco0kkAAA24SURBVIjkcWdCt6weDc5FnImiFyyt5HMnD2dYv/wkRXh4DMPge2cfzz2vfcxNOnKnXird1FytOWjy/AfbeXzFVo4bXMwd549j4pH90x2WyDKuLNT1qpdLjH7o2/Y388pHu7hmxugkRnj4DMPgu2cdx6mj+nOjXs3+5lC6QxK9UN8SYsE7lXzxz2/z/vY67r54Ar+bO1GSuUgJ1yV0y7axLLvHg/MXxeiH/sc3N3PpqUcyoDCQxCiTwzAMvl12HDPGDOTGRavZ1xRMd0giQfuagvz+9U+46M/vUFXbxAOXncpdF57CiV1MMC5Esriu5NJWP+9pzbGgi0bRjbvqWV61j//67NgkR5k8hmFw8+xj8XoMvrloNferyQwsyryLj4ioqW/hsRVb+Nu6nXx67BAe+spUjuwvj+aLvuHChN7zcgsxBuf6wxubuXbG6Izv820YBt+cdQxew+CGRau5X01iUIzp8BobWtm9ow7DY1BYlEdBYYCCogBe6c+cUtv2NfHw8i28unEXnz95OE9eM50hJYlPWShEMmR2JutCqBfTyOGUXJqiauirt+3nkz2N3HXhKUmOMDUMw+Abs47B49yp//aC8YQPNFNTXceuHXXtP0NBkyHD+2EYBk2NrTQ1BmlpCuIP+CgsClBQlEdhUSDyujCPwuIABYWBSPIvivxs266g0I9H+kN365M9DTy0tIp3Kmv50sQRPPu1GfTPwPKdyA0uTOg2gV7coef7PQTDFqZl4zHg969/wjdOPzrjxyNvbQmxa+cBdlXXUVO9H/+OOk6qrOW+dysZPrI/I4/sz9ARpYwdP4Ihw0vp17/gkHKUZdm0toRobgzS1NhKc1Ow
Pdk3NwbZt7eR6i37IusagzQ561ubQwTyfBgei/VL/+lcCCKJv6Aw6sLQfpGIrMsvCOBxyeQdvfXhzgMsWFrFB9v3c9mUUfzoM2PlmQGRdq77C+zNwFw4d7gF/sg0dKu37ae+Ncy544alJMbeCAXD7K5pS9x11OyoY1d1HQ0HWhg0tIShI/ozZEQpM8YO5YLhpbz48R5eXLOD+784KW53S4/HiJReCgMMHJz4rCuWZdPSHKTi1Tc4ZfykyAWgKUhTQ+SisHd3PdsqO14cmhpbCbaGySvwU9jpzj9yIQhQWJwXWVec1+HbQV6+P+MvBKu37WfBO5Vs3tPIldOO4mefGyejZIqM4bqEHjatXj8WXeD30tAa5vevf8KNZxzT454yyRAOm+ytqW9P2m2lkrp9TRwxuJghw0sZOqKUKZ86hqEj+jNwUFGXZY9rBhXj83q44al3eeCyU1PSh97j1OELS3yMOjrx2aEsy6K5KUSzk+jb7/ydbwe7dxygualtXbB9u1AwTH5hoFOyd8o/beWgwoBTJspr/5mX70vpgzm2bfNOZS0L36lkd0MrV08fzd0XTcj4b3ci97guoe+qrmPMR9UsXbKJCVOOojBG42BXCgM+nv+gmsKAl9k9mL6uN0zTonZ3AzXVdXyyrp4dm95k1446anc30P+IIoYOL2XIiP5MnDqaoSNKOWJISY8bLq+cehQew+AbT73LHy6dnDETHXg8HoqK8yjqwb8Nzu+sue1bQPtd/8GLQs32/e3fEJqaDq4Lh8yO9f+otoGCojy2bW9kTckW5+JwsFwUyOv+QmDZNhWbdvPQ0ipawhbXzhjNZ2ScFZHBXJfQCwYUUj9qIFUf7+KVF97nmLFDmDTtaE4YP5xAnN4qRQEvjy3fwr2XTEzaHZ1l2ezb29BeKmlroNy7q56S0gKGjCjFMm3GTTySsvNOZtDQfviT+BX98tNG4fMY3PDUau6/dLKru8h5vR6K++VT3MNvG+GQSXNTsEOyj/5WUL8vzNp3t0RdCCLLLctu/xZQWJRHIM+Hx+vB8BjsbgyyeW8TXq/B9KElHDmgEHPNdl5ZX43H68Hr8eDxGpHXXgOPx4PX68HjiVrmvO9qeYdlzs+O+4r8jCw3sCwb27ZliADRLdcl9BOGlXD7NdMZXlpAS3OQtau3snTJRvTCtwnk+ehXWkBJaQH9+rf9LIws619AkWUxdUgxo/N91GzfTyhsEg6ahMMmoaBJOGwRCoY7/YysDwfNQ7avP9DCnpoDFBQGGDqilCHDSznupGF8as5YhgwrJeA0klVUVDBx2piU/U7UqUfiMeAG50591IDClB0rE/n8Xkqcf/euVFTsp6xs1iHLQyGT5vbEH6SxKcjyyr1UbNhFab6P2TPHcPSAQizLxjItLMvGjP5p2oRaw7RaFqZpY7X9NC0ss9Myy3KW25iW89PZl9Vpn6bz+Q7HC5uUP/1U+0Ug1oXi0OXOhccTvU2nC4dzMYn+bPvnoi440fv1eDtt6+m4rPP2DXUhdu880OV+OxxPvv0cFtcl9Dyfl+HO/7j5BQFO+9SxnPapY7Esm+bGVg7UNXNgfzP1dc0cqGumpno/H3+4gwP7mxm6pwG/x+ChDTvw+T34/D78fg8+vxef34s/6mfba5/fS2FhAF+psy7gxeeL/CwsymPI8H7kF6S/m9rcyUfi8bT1U5/M6IHpSeqWbWNazn/OU72R12BaNpZtE3Z+tm93yLa2sy0dtrUsm3CM7Tq8j3ptWjabt1msWfJx19tGbffu1v0cO7iI714xlckZ9mh+RUUFZ555ZvsFIPoiYbZdFMzOF5wuLhSHXECi99V2ATp4YTJDFiFnWfQFpsvjtcfm7Csqtvr6Bja++/ohx+78WZySXeeL0CEXL+fbi7eLC8whF6EuLnbRF6F437K8Xg8125r58IPtHS4+CX3Lir6AOtun8luW6xJ6LB6PQVFJPkUl+Qw/ckC6w0mLiyeOxOcxuHHRak4/5ohD
kmfnBNn2X/v7GAmyodHino1vRz4blXzDXWyLM2GHz2PgMQy8HgOvEVnm7bDMeR/1OrKeDp/1GJF9dfis8xmPh4776mLfPsPA7zEoyffhNTx4nc8csq1hcOXUoxibwY/mG4aB10kkbhuqraKigrKysrjbWV18e4l5wbKiLhydLl5W54uQmdi3rI4XrIMXu5qaZloPfNztt6yOrw89Xtt7j8fgW7edz+Bh/ZL+e056QldK3QHMAfKBb2itVyb7GCK2C04ZwbB++Wzd14zH6JjwDkmkHRIaeD2e9uQbnTxXrljOzBmTD02sUdv6nH1m4njeFaEqyqanruQlksfj8eDxRMpomSRyQTrzsPdjOzdDqbpLT2pCV0qdBUzVWp+ulBoP3A/MTuYxRHzTRg9kWhIHj6zKN1zd2CpEpmj7lpUqyW6BOAtYDKC1XguMUErlVgudEEKkSbIT+nBgd9T73cDQJB9DCCFEF5JdQ+88YLcB2LE2rqioiLvDhoaGhLbLZG4/B4k//dx+DhJ/30h2Qt8BDIl6PxioibVxIq3eibaOZzK3n4PEn35uPweJP3nKy8tjrkt2yeXvwIVEGkhPBTZrrZuTfAwhhBBdSGpC11qvAt5XSr0LPADcmsz9CyGEiC3p/dC11j8CfpTs/QohhOieDJwghBBZwrDtmJ1QUqq8vDw9BxZCCJebM2dOl08npS2hCyGESC4puQghRJaQhC6EEFlCEroQQmQJSehCCJElJKELIUSWyOgZi9w6WYYzFvxi4B6t9X1KqSHAI0B/YBtwhda6Nd1xxqKU+rkzFLIfuAtY4pb4neGaH3JG+SwC7gCWuiX+NkqpAmCdE//f3BS/UqoMeNqJH2AN8DOXncPlwPecAQZvA1a4If6MvUOPniwDuBr4TbpjSoRSqgi4F4geQedXwEKt9QygErgijSF2Syk1G5iktZ4JfBa4x03xAxcAK7XWZwJfAn7tsvjb/ATY67x2Y/xLtNZlzn+3uOkclFLFTjI/Hfg88EW3xJ+xCd3Fk2W0AucD1VHLyoAXndeLgXPSFFsi3gaU83o/EADOdkv8WuuntNb/67w90rmbctPvH6XUicBJwMvOIlfFH4ObzuEc4GWtdYvWulprfb1b4s/kkstw4P2o922TZfwnjTHFpbUOA2GlVPTikqhRJ3cBw9ITXXxO/A3O2+ucr/sXuCX+NkqpZU6c5wNvuCz+XwM3A9c4713z9xNlnFLq70AJcLvLzmEUMNiJvxj4qVviz+Q79B5NlpHhos/FFeehlLoQuB74jhvj11pPBy4CngLCUasyOn6l1FXA61rryqjFbvv9bwLuBD4HfAV40Im7TaafQ55z8/h54KtOm4wr/oYyOaH3aLKMDFcfVS4a1qkck3GUUucA84Bztdb73RS/Uuo0pdRRRJL6u87feKNb4neS4Fyl1FLnG9JtQLOL4kdrvV1r/YTW2tJa/wfYCRS76Bx2Au9orU2t9SbggFv+hjI5oWfTZBn/aDsX4OKo2mjGUUqVOg3Q52ut2xrlXBM/8CnnWwVKqaHOV/6X3BK/1vpSrfU0p/HtQad3iGviJ/J7v0wpNd95Pci5233QRefwb+BspZTh9FBzzd9QRg/OpZS6C/iM83Xna1rrNemOKR6l1BTgbmAMEAK2Oy3ijzvd6DYA1zi16oyjlPo6MB/YGLX4auBhl8SfByx06qB5Tre/VcCTbog/mpMUK4F/uil+p5fIo843bI9zUVrtsnP4OnB5VBvACjfEn9EJXQghROIyueQihBCiByShCyFElpCELoQQWUISuhBCZAlJ6EIIkSUkoQshRJaQhC6EEFni/wMP+PTk2IAC0AAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "bento_obj_id": "139882812866128",
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Same RMSE curves as the previous cell, drawn on linear axes.\n",
    "fig, ax = plt.subplots()\n",
    "for name, rmse_curve in result_maps.items():\n",
    "    ax.plot(xs, rmse_curve[: len(xs)], label=name)\n",
    "\n",
    "ax.set_xlabel(\"Episode length\")\n",
    "ax.set_ylabel(\"RMSE\")\n",
    "ax.set_title(\"OPE estimator RMSE vs. episode length (linear axes)\")\n",
    "ax.legend(loc=\"best\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  }
 ],
 "metadata": {
  "bento_stylesheets": {
   "bento/extensions/flow/main.css": true,
   "bento/extensions/kernel_selector/main.css": true,
   "bento/extensions/kernel_ui/main.css": true,
   "bento/extensions/new_kernel/main.css": true,
   "bento/extensions/system_usage/main.css": true,
   "bento/extensions/theme/main.css": true
  },
  "disseminate_notebook_info": {},
  "kernelspec": {
   "display_name": "reagent (local)",
   "language": "python",
   "name": "reinforcement_learning_local"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5+"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
